# -*- coding: utf-8 -*-
"""
Created on Tue Dec 21 22:21:17 2021

@author: LSY
"""

from __future__ import division
import numpy as np
import mpctools as mpc
import scipy.io as spio
import casadi
import timeit

# Compact console output for the large matrices printed below.
np.set_printoptions(precision=4, linewidth=400)
from Sens_proj1 import SenMatrix, SensAnal
from model_wwtp78 import ode_wwtp_scaled

# ---- Problem dimensions and simulation settings ----
rank_xp = True              # forwarded unchanged to SensAnal
Nx = 78                     # size of the state vector x
Nu = 16                     # size of the input vector u
Nymax = order = Nw = Nx     # `order`: number of Lie derivatives, 0 .. order-1
Ny = 16                     # number of candidate measurements
Nv = Ny
Delta = 15 / 24 / 60        # sampling time: 15 min (presumably in days -- confirm model time unit)
Nsim = 20                   # number of sampling instants

# Symbolic state and input column vectors.
x = casadi.SX.sym('x', Nx, 1)
u = casadi.SX.sym('u', Nu, 1)

# Continuous model, its single-step RK4 discretisation, and a discrete-time
# simulator built from the same right-hand side.
F = mpc.getCasadiFunc(ode_wwtp_scaled, [Nx, Nu], ["x", "u"], "cstrmodel")
F_rk4 = mpc.getCasadiFunc(ode_wwtp_scaled, [Nx, Nu], ["x", "u"],
                          "ode_rk4", rk4=True, Delta=Delta, M=1)
wwtp = mpc.DiscreteSimulator(ode_wwtp_scaled, Delta, [Nx, Nu], ["x", "u"])

# ---- Noise realisations ----
sigma_w = 0.01                                 # std of additive state noise w
sigma_v = 0.01                                 # std of additive measurement noise v
w = sigma_w*np.random.randn(Nsim, Nw)
v = sigma_v*np.random.randn(Nsim, Nv)

# ---- Input trajectory ----
# Columns of usim: 0 -> KLa5, 1 -> Qa, 2 -> Q0 (influent file column 14),
# 3..15 -> influent file columns 1..13.
data = np.loadtxt('Inf_dry_2006.txt')
KLa5 = 84
Qa = 55338
Q0 = data[0:Nsim+1, 14]
Z0 = data[0:Nsim+1, 1:14]
usim = np.zeros((Nsim+1, Nu))
usim[:, 0] = KLa5
usim[:, 1] = Qa
usim[:, 2] = Q0
usim[:, 3:16] = Z0

# ---- Open-loop simulation with additive state noise ----
xsim = np.zeros((Nsim+1, Nx))
# Initial state: every entry 0.5, sized by Nx so it tracks the model dimension.
xsim[0, :] = np.full(Nx, 0.5)
for t in range(Nsim):
    xsim[t+1, :] = wwtp.sim(xsim[t, :], usim[t, :]) + w[t, :]

print("Calculate the sensitive matrix with all measurements ......")
Tstart = timeit.default_timer()

# Indices of every candidate measurement; M_all and M_1 name the same list.
M_all = M_1 = list(range(Ny))
print("List all the possible measurement combos:\n ", M_all)
ysim = np.zeros((Nsim, Ny))  # one row of measurements per sampling instant
def measurement_all(x):  # Zeng's paper uses reactor 3, but the MATLAB code uses reactor 5
    """Return the 16 candidate measurements for the state vector ``x``.

    Each measurement is the sum of a fixed subset of state entries, i.e.
    ``y = C @ x`` with a 0/1 matrix ``C`` of shape ``(16, n)``, where ``n`` is
    the first dimension of ``x``.

    The row count is fixed by the table below rather than read from the
    module-level ``Ny``. The sensor-removal loop in this script decrements
    ``Ny``, and a matrix sized from it would no longer have room for all 16
    rows.

    Parameters
    ----------
    x : numpy array or CasADi symbol
        State vector. Its first dimension must be at least 78, because the
        largest state index used below is 77.

    Returns
    -------
    y
        ``np.dot(C, x)``. This has shape ``(16,)`` for a 1-D numpy ``x`` and
        ``(16, 1)`` for a column vector.
    """
    # State indices summed into each measurement row (labels from the author).
    rows = (
        [20],                          # SO2
        [33],                          # SO3
        [46],                          # SO4
        [59],                          # SO5
        [13, 14, 15, 16, 17, 18],      # COD2
        list(range(13*4, 13*4 + 6)),   # COD3/5
        # NOTE(review): unlike COD2/COD3-5 this skips 69-70 and adds 72 --
        # confirm against the model's state ordering.
        [65, 66, 67, 68, 71, 72],      # COD6
        [13, 14],                      # CODf2
        [13*4, 13*4 + 1],              # CODf3/5
        [25],                          # SALK2
        [38],                          # SALK3
        [77],                          # SALK6
        [15, 16, 17, 18, 19, 24],      # TSS2
        # TSS3 alternative: [28, 29, 30, 31, 32, 37]
        [54, 55, 56, 57, 58, 63],      # TSS5
        [73],                          # SNO6
        [74],                          # SNH6
    )
    # np.shape reads `.shape` when present (numpy arrays, CasADi symbols).
    C_extend = np.zeros((len(rows), np.shape(x)[0]))
    for r, cols in enumerate(rows):
        C_extend[r, cols] = 1
    return np.dot(C_extend, x)


# Noisy measurement trajectory: y_t = measurement_all(x_t) + v_t, t = 0 .. Nsim-1.
for t in range(Nsim):
    ysim[t, :] = measurement_all(xsim[t, :]) + v[t, :]  # measurement plus noise v
# CasADi wrapper of the measurement map, passed to SenMatrix below.
H = mpc.getCasadiFunc(measurement_all, [Nx], ["x"], "measurement")

SM_all = SenMatrix(F_rk4, H, xsim[0:Nsim, :], ysim[0:Nsim, :], usim[0:Nsim-1, :], Nsim-1)
Tend = timeit.default_timer()
print("The sensitive matrix SM = \n", SM_all)
print('Cost time:', str(Tend-Tstart), 's')
####################################################################

#### Sensitivity degree with all the measurements ####
Tstart = timeit.default_timer()
Flag = SensAnal(SM_all, Nx, Ny, sigma_w, sigma_v, Nsim-1, rank_xp)
# This script reads the rank from Flag[0] and the sensitivity degree from Flag[-1].
Rank_Sen, Sen_deg = Flag[0], Flag[-1]
print('Rank of Sensiti:', Rank_Sen)
print("Sen degree:", Sen_deg)
Tend = timeit.default_timer()
print('Cost time:', str(Tend-Tstart), 's')
########################################################



########### Measurements removed a priori (list their indices in R_known):
M_remain = list(range(Ny))
R_known = []
# Derived rather than set by hand, so it can never disagree with R_known.
N_known = len(R_known)

print("These measurements have been removed priori:\n", R_known)
for i in R_known:
    M_remain.remove(i)
Ny = Ny - N_known
# Each removal round drops one remaining measurement, so there are at most
# len(M_remain) rounds. Nx counts states, not measurements.
Nr_max = len(M_remain)
print("Remain these measurements:\n", M_remain)
####################################################


###############################begin remove ########################################################
print("------------------------------------------Begin Remove-------------------------------------")

# Greedy backward elimination. In each round:
#   1. tentatively remove every remaining measurement;
#   2. score the reduced sensitivity matrix with SensAnal;
#   3. commit the candidate with the largest sensitivity degree.
# The search stops when no candidate keeps the rank at Nx.
#
# SM_all is assumed to stack one block of `n_meas_total` measurement rows per
# sampling instant (the layout the row-selection matrix below relies on).
# The stride is computed here because Ny is decremented inside the loop.
n_meas_total = len(M_remain) + len(R_known)

# The results file is opened once and closed by the with-statement, even when
# the search breaks early or never enters a round.
with open("Results.txt", "a+") as REs:
    for k in range(Nr_max):
        print(len(R_known)+1, "- th remove", "(remian:", len(M_remain), "measurements)")
        Rank_Sen_remain_max = 0   # best rank reached by any candidate this round
        Sen_deg_max = 0           # best sensitivity degree this round
        M_remove = None           # measurement chosen for removal this round
        M_remain_copy = []
        Ny = Ny - 1               # one fewer measurement once this round commits
        for i in M_remain:
            Tstart = timeit.default_timer()
            M_remove_try = i
            R_known_try = R_known + [M_remove_try]
            M_remain_try = list(M_remain)
            M_remain_try.remove(M_remove_try)
            print('Try to remove:', M_remove_try)
            print('Try to remove:', M_remove_try, file=REs)
            print('Remian these measurements', M_remain_try)
            print('Remian these measurements', M_remain_try, file=REs)

            # Delete the rows of every removed measurement at every sampling
            # instant by left-multiplying SM_all with a row-selection matrix.
            T_list = [nj + ni*n_meas_total for ni in range(Nsim) for nj in R_known_try]
            TT = np.delete(np.eye(Nsim*n_meas_total), T_list, axis=0)
            SM_I = np.dot(TT, SM_all)

            Flag = SensAnal(SM_I, Nx, Ny, sigma_w, sigma_v, Nsim-1, rank_xp)
            Rank_Sen_remain = Flag[0]
            Sen_deg = Flag[-1]
            print('Rank of Sensiti:', Rank_Sen_remain)
            print('Rank of Sensiti:', Rank_Sen_remain, file=REs)
            print("Sen degree:", Sen_deg)
            print("Sen degree:", Sen_deg, file=REs)

            if Rank_Sen_remain > Rank_Sen_remain_max:
                Rank_Sen_remain_max = Rank_Sen_remain
            # NOTE(review): the pick below ignores this candidate's own rank, so
            # a rank-deficient candidate can win while another keeps full rank --
            # confirm this is intended for SensAnal's degree.
            if Sen_deg > Sen_deg_max:
                Sen_deg_max = Sen_deg
                M_remove = M_remove_try
                M_remain_copy = M_remain_try

            Tend = timeit.default_timer()
            print('Cost time:', str(Tend-Tstart), 's')
            print("-----------------------------------------------------------------------------------")
            print("-----------------------------------------------------------------------------------", file=REs)

        if Rank_Sen_remain_max < Nx:
            print("stop trying")
            print("stop trying", file=REs)
            break
        if M_remove is None:
            # Full rank is still reachable, but no candidate had a positive
            # sensitivity degree. Without this guard the round would commit an
            # unset (or previous-round) index and empty M_remain.
            print("no candidate with positive Sen degree; stop trying")
            print("no candidate with positive Sen degree; stop trying", file=REs)
            break
        print("new removed:", M_remove)
        print("new removed:", M_remove, file=REs)
        M_remain = M_remain_copy
        R_known.append(M_remove)
        print("removed measurements:\n", R_known)
        print("removed measurements:\n", R_known, file=REs)
        print("remians measurements:\n", M_remain)
        print("remians measurements:\n", M_remain, file=REs)
    print("All removed measurements:\n", R_known)
    print("All removed measurements:\n", R_known, file=REs)
print("Finially remians:", M_remain)